-- «•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»«•»

{-
 * n-body problem
 * based on http://shootout.alioth.debian.org/gp4/benchmark.php?test=nbody&lang=java&id=0
 *
 Is slower than Java by a factor of more than ten; the cause is the
 slowdown introduced through the non-inlined ST implementation.
-}


package examples.FBodies 
        inline (sqr, Body.x, Body.y, Body.z, Body.vx, Body.vz, Body.vy, Body.mass)
    where


-- import Data.TreeMap  -- Show instance of list
import frege.prelude.Math (pi, sqrt)


-- One body of the simulated system.
--   x, y, z     position coordinates (advanced by `advance` using vx, vy, vz)
--   vx, vy, vz  velocity components
--   mass        mass of the body (the sun is created with `solar_mass`)
-- NOTE(review): the `!` in front of the constructor presumably makes the
-- fields strict -- confirm against the Frege version in use.
data Body = ! Body {x::Double, y::Double, z::Double,
                vx::Double, vy::Double, vz::Double,
                mass::Double }

-- Show instance; used by the debugging helper `zeig` via `Body.show`.
derive Show Body

{-
data Vector = native [] Double where
        fromList [] = Vector.new 0
        fromList is =
            let
                ilen = is.length
                iarr = Vector.new ilen
                loop (iarr::Vector) j (x:xs) = do
                        void <- iarr.[j <- x]
                    for loop iarr (j+1) xs
                loop (iarr::Vector) _ []     = iarr
            in loop iarr 0 is

// name the indexes
x = 0; y = 1; z = 2; vx = 3; vy = 4; vz = 5; mass = 6
-}

-- ArrayElement instance for Body; the code below stores bodies in arrays
-- (JArray.genericFromList, getElemAt, setElemAt, elemAt).
derive ArrayElement Body
-- type Vector = Body
-- Array of bodies, parameterised by the state thread `s` of the ST actions
-- that read and write it.
type Bodies s = ArrayOf s Body

-- some fundamental constants
-- pix = 3.141592653589793
-- Mass of the sun; the planets' masses in `initSystem` are given as
-- multiples of this value.
solar_mass = 4.0 * pi * pi
-- Number of days per year; `initSystem` multiplies the planets' velocity
-- components by this factor.
days_per_year = 365.24


-- Returns a copy of `body` whose velocity is replaced by the negated
-- momentum (px, py, pz) divided by `solar_mass`; position and mass are
-- carried over unchanged.
offsetMomentum (body::Body) px py pz = Body {
    x = body.x, y = body.y, z = body.z,
    vx = rescale px, vy = rescale py, vz = rescale pz,
    mass = body.mass }
  where
    -- one momentum component turned into the compensating velocity
    rescale p = negate p / solar_mass


-- Builds the array of the five bodies (sun, jupiter, saturn, uranus,
-- neptune) and then replaces the body at index 0 by the result of
-- `offsetMomentum` applied to the summed momentum of all bodies.
initSystem :: ST s (Bodies s)
initSystem = do
        system <- planets
        balance 0.0 0.0 0.0 0 system
        return system
    where
        -- Walks the array from index `idx`, accumulating velocity * mass per
        -- axis in (mx, my, mz); once past the last element it rewrites
        -- element 0 with the accumulated momentum.
        balance :: Double -> Double -> Double -> Int -> Bodies s -> ST s ()
        balance !mx !my !mz idx arr = do
            count <- arr.getLength
            if idx < count
                then do
                    body <- getElemAt arr idx
                    balance (mx + body.vx * body.mass)
                            (my + body.vy * body.mass)
                            (mz + body.vz * body.mass) (idx+1) arr
                else do
                    b0 <- getElemAt arr 0
                    setElemAt arr 0 (offsetMomentum b0 mx my mz)

        -- planets :: ST s (Bodies s)
        -- Velocities are scaled by `days_per_year`, masses by `solar_mass`.
        planets = JArray.genericFromList [sun, jupiter, saturn, uranus, neptune] where
            sun = Body {
                x  = 0.0, y  = 0.0, z  = 0.0,
                vx = 0.0, vy = 0.0, vz = 0.0,
                mass = solar_mass }

            jupiter = Body {
                x    = 4.84143144246472090e+00,
                y    = (negate 1.16032004402742839e+00),
                z    = (negate 1.03622044471123109e-01),
                vx   = (1.66007664274403694e-03 * days_per_year),
                vy   = (7.69901118419740425e-03 * days_per_year),
                vz   = (negate 6.90460016972063023e-05 * days_per_year),
                mass = (9.54791938424326609e-04 * solar_mass) }

            saturn = Body {
                x    = 8.34336671824457987e+00,
                y    = 4.12479856412430479e+00,
                z    = (negate 4.03523417114321381e-01),
                vx   = (negate 2.76742510726862411e-03 * days_per_year),
                vy   = (4.99852801234917238e-03 * days_per_year),
                vz   = (2.30417297573763929e-05 * days_per_year),
                mass = (2.85885980666130812e-04 * solar_mass) }

            uranus = Body {
                x    = 1.28943695621391310e+01,
                y    = (negate 1.51111514016986312e+01),
                z    = (negate 2.23307578892655734e-01),
                vx   = (2.96460137564761618e-03 * days_per_year),
                vy   = (2.37847173959480950e-03 * days_per_year),
                vz   = (negate 2.96589568540237556e-05 * days_per_year),
                mass = (4.36624404335156298e-05 * solar_mass) }

            neptune = Body {
                x    = 1.53796971148509165e+01,
                y    = (negate 2.59193146099879641e+01),
                z    = 1.79258772950371181e-01,
                vx   = (2.68067772490389322e-03 * days_per_year),
                vy   = (1.62824170038242295e-03 * days_per_year),
                vz   = (negate 9.51592254519715870e-05 * days_per_year),
                mass = (5.15138902046611451e-05 * solar_mass) }

{-
    public double energy(){
        double dx, dy, dz, distance;
        double e = 0.0;

        for (int i=0; i < bodies.length; ++i) {
            e += 0.5 * bodies[i].mass *
               ( bodies[i].vx * bodies[i].vx
               + bodies[i].vy * bodies[i].vy
               + bodies[i].vz * bodies[i].vz );

            for (int j=i+1; j < bodies.length; ++j) {
                dx = bodies[i].x - bodies[j].x;
                dy = bodies[i].y - bodies[j].y;
                dz = bodies[i].z - bodies[j].z;

                distance = Math.sqrt(dx*dx + dy*dy + dz*dz);
                e -= (bodies[i].mass * bodies[j].mass) / distance;
            }
        }
        return e;
    }
-}

-- Square of a Double.  Listed in the package's `inline` clause, so it is
-- kept as a single plain equation.
sqr :: Double -> Double
sqr value = value * value
-- pure native sqrt java.lang.Math.sqrt :: Double -> Double

-- Total energy of the bodies in `bs`: for every body the kinetic term
-- 0.5 * mass * (vx^2 + vy^2 + vz^2), minus mass_i * mass_j / distance for
-- every pair i < j.
energy :: JArray Body -> Double
energy !bs = fold perBody 0.0 [0..count-1] where
    !count = bs.length
    -- adds the contribution of body `i` (kinetic term and all pairs (i, j)
    -- with j > i) to the running total `acc`
    perBody !acc !i = fold perPair (kinetic here) [i+1 .. count-1]
      where
        here = elemAt bs i
        -- running total plus the kinetic term of `b`
        kinetic :: Body -> Double
        kinetic b = acc + 0.5 * b.mass * (sqr b.vx + sqr b.vy + sqr b.vz)
        -- subtracts the pair term of bodies `i` and `j` from `total`
        perPair !total !j = total - here.mass * there.mass / separation
            where
                there = elemAt bs j
                !dx = here.x - there.x
                !dy = here.y - there.y
                !dz = here.z - there.z
                !separation = sqrt (sqr dx + sqr dy + sqr dz)
                

-- dt = 0.01

-- Debugging aid: passes "<label>: <shown body>" to `traceLn`.
zeig label body = traceLn (label ++ ": " ++ Body.show body)

-- Performs one simulation step with the fixed time step 0.01, in place:
-- first the velocities of every pair of bodies (i < j) are updated, then
-- every body's position is moved along its velocity.
advance :: Bodies s -> ST s ()
advance bs = do
        count <- bs.getLength
        -- pass 1: velocity changes for all pairs
        sequence_ [ kick i j | i <- [0 .. count-2], j <- [i+1..count-1]]
        -- pass 2: position changes for all bodies
        sequence_ [ drift k | k <- [0..count-1]]
    where
        -- replaces the body at index `k` by its moved copy
        drift k = do
            body <- getElemAt bs k
            setElemAt bs k (moved body)

        -- copy of `b` with the position advanced by 0.01 * velocity
        moved :: Body -> Body
        moved b = Body {
                    x = b.x + (0.01 * b.vx),
                    y = b.y + (0.01 * b.vy),
                    z = b.z + (0.01 * b.vz),
                    vx = b.vx, vy = b.vy, vz = b.vz,
                    mass = b.mass}

        -- updates the velocities of the bodies at indexes `i` and `j`:
        -- body i is decelerated and body j accelerated along their
        -- difference vector, each weighted by the other body's mass
        kick !i !j = do
            p <- getElemAt bs i
            q <- getElemAt bs j
            let
                    !dx = p.x - q.x
                    !dy = p.y - q.y
                    !dz = p.z - q.z
                    !dist = sqrt (dx*dx + dy*dy + dz*dz)
                    !mag = 0.01 / (dist * dist * dist)
                    !newP = Body {x = p.x, y = p.y, z = p.z,
                                  vx = p.vx - (dx * q.mass * mag),
                                  vy = p.vy - (dy * q.mass * mag),
                                  vz = p.vz - (dz * q.mass * mag),
                                  mass = p.mass}
                    !newQ = Body {x = q.x, y = q.y, z = q.z,
                                  vx = q.vx + (dx * p.mass * mag),
                                  vy = q.vy + (dy * p.mass * mag),
                                  vz = q.vz + (dz * p.mass * mag),
                                  mass = q.mass}
            setElemAt bs i newP
            setElemAt bs j newQ
        
            

-- Runs `advance` on `bs` n times; does nothing when n is not positive.
loop !n !bs
    | n > 0     = do
        advance bs
        loop (n-1) bs
    | otherwise = return ()

-- Native binding to java.text.NumberFormat; `main` uses it to print the
-- energy values.  All operations run in ST on a mutable instance.
data NumberFormat = native java.text.NumberFormat where
    -- obtains an instance through the static java.text.NumberFormat.getInstance
    native mk java.text.NumberFormat.getInstance :: () -> ST s (Mutable s NumberFormat)
    -- NOTE(review): the remaining members give no explicit Java name, so they
    -- presumably bind to the instance methods of the same name -- confirm.
    native setMaximumFractionDigits :: Mutable s NumberFormat -> Int -> ST s ()
    native setMinimumFractionDigits :: Mutable s NumberFormat -> Int -> ST s ()
    native setGroupingUsed :: Mutable s NumberFormat -> Bool -> ST s ()
    -- formats a Double as a String
    native format :: Mutable s NumberFormat -> Double -> ST s String


-- Entry point.  Expects the number of simulation steps as first argument.
-- Prints the system's energy with exactly 9 fraction digits (no grouping)
-- before and after running the steps; the class name of the body array is
-- written to stderr.  Without a parseable first argument a usage message
-- is written to stderr instead.
main (arg:_)
    | Right !steps <- arg.int = do
            fmt <- NumberFormat.mk ()
            fmt.setMaximumFractionDigits 9
            fmt.setMinimumFractionDigits 9
            fmt.setGroupingUsed false
            system <- initSystem
            stderr.println system.getClass.getName
            before <- readonly energy system
            fmt.format before >>= println
            loop steps system
            after <- readonly energy system
            fmt.format after >>= println
main _ = stderr.println "usage: java examples.FBodies steps"
